{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Activation Histograms\n",
    "\n",
    "This notebook shows an example of how to generate activation histograms for a specific model and dataset.\n",
    "\n",
    "## But I Already Know How To Generate Histograms...\n",
    "\n",
    "If you already generated histograms using Distiller outside this notebook, you can still use this notebook to visualize the data:\n",
    "* To load the raw data saved by Distiller and visualize it, go to [this section](#Plot-Histograms)\n",
    "* If you enabled saving histogram images and want to view them, go to [this section](#Load-Histogram-Images-from-Disk)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import math\n",
    "import torchnet as tnt\n",
    "from ipywidgets import widgets, interact\n",
    "\n",
    "import distiller\n",
    "from distiller.models import create_model\n",
    "\n",
    "device = torch.device('cuda')\n",
    "# device = torch.device('cpu')\n",
    "\n",
    "# Load some common code and configure logging\n",
    "# We do this so we can see the logging output coming from\n",
    "# Distiller function calls\n",
    "%run './distiller_jupyter_helpers.ipynb'\n",
    "msglogger = config_notebooks_logger()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Your Model\n",
    "\n",
    "For this example we'll use a pre-trained image classification model.\n",
    "\n",
    "### Note on Parallelism\n",
    "\n",
    "Currently, Distiller's implementation of activation histogram collection does not accept models that contain [`DataParallel`](https://pytorch.org/docs/stable/nn.html?highlight=dataparallel#torch.nn.DataParallel) modules, so here we create the model without parallelism to begin with. If you have a model that includes `DataParallel` modules (for example, if loaded from a checkpoint), use the following utility function to convert the model to serialized execution:\n",
    "```python\n",
    "model = distiller.utils.make_non_parallel_copy(model)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "model = create_model(pretrained=True, dataset='imagenet', arch='resnet18', parallel=False)\n",
    "model = model.to(device)  # Comment out if not applicable"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare Data\n",
    "\n",
    "Usually there is no need to collect histograms over the entire dataset. A representative subset is enough, and it also reduces the runtime.\n",
    "* **Subset size:** There is no golden rule for selecting the size of the subset. Anywhere between 1% and 10% of the validation/test set should work.\n",
    "* **Representative data:** Whatever size is chosen, the subset should cover as much of the data distribution as possible. For example, if the dataset is organized by class by default, select items randomly and not in order.\n",
    "\n",
    "**Note:** Restricting the run to a subset can be handled at data preparation time, or it can be deferred to the model evaluation function (for example, by executing only a specific number of mini-batches). In this example we handle it during data preparation."
   ]
  },
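  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two guidelines above can be illustrated with a small standard-library sketch that picks a random but reproducible subset of sample indices. The helper `pick_subset_indices` and its `seed` parameter are hypothetical names used only for this illustration. They are not part of Distiller, and in this notebook `load_data` takes care of the selection.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "def pick_subset_indices(num_samples, fraction, seed=0):\n",
    "    # Hypothetical helper, not a Distiller API.\n",
    "    # Shuffle all indices so that a class-ordered dataset is sampled evenly,\n",
    "    # then keep the first 'fraction' of them.\n",
    "    rng = random.Random(seed)  # a fixed seed yields the same subset on every call\n",
    "    indices = list(range(num_samples))\n",
    "    rng.shuffle(indices)\n",
    "    subset_len = max(1, round(num_samples * fraction))\n",
    "    return indices[:subset_len]\n",
    "\n",
    "first_pass = pick_subset_indices(50000, 0.01)\n",
    "second_pass = pick_subset_indices(50000, 0.01)\n",
    "print(len(first_pass), first_pass == second_pass)\n",
    "```\n",
    "\n",
    "Because the seed is fixed, both calls return the same indices, which is useful when two separate passes over the data must see the same samples."
   ]
  },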
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# We use Distiller's built-in data loading functionality for ImageNet,\n",
    "# which takes care of randomizing the data before selecting the subset.\n",
    "# While it creates train, validation and test data loaders, we're only\n",
    "# interested in the test dataset in this example.\n",
    "#\n",
    "# Subset size: Here we'll go with 1% of the test set, mostly for the\n",
    "# sake of speed. We control this with the 'effective_test_size' argument.\n",
    "#\n",
    "# We set the 'fixed_subset' argument to make sure we're using the\n",
    "# same subset for both phases of histogram collection - more on that below\n",
    "\n",
    "dataset = 'imagenet'\n",
    "dataset_path = '/datasets/imagenet'\n",
    "arch = 'resnet18'\n",
    "batch_size = 256\n",
    "num_workers = 10\n",
    "subset_size = 0.01\n",
    "\n",
    "_, _, test_loader, _ = distiller.apputils.load_data(\n",
    "    dataset, arch, os.path.expanduser(dataset_path), \n",
    "    batch_size, num_workers,\n",
    "    effective_test_size=subset_size, fixed_subset=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the Model Evaluation Function\n",
    "\n",
    "We define a fairly bare-bones evaluation function. Recording the loss and accuracy isn't strictly necessary for histogram collection, but we record them anyway so we can verify that the data subset achieves results on par with what we'd expect from a representative subset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def eval_model(data_loader, model):\n",
    "    print('Evaluating model')\n",
    "    criterion = torch.nn.CrossEntropyLoss().to(device)\n",
    "    \n",
    "    loss = tnt.meter.AverageValueMeter()\n",
    "    classerr = tnt.meter.ClassErrorMeter(accuracy=True, topk=(1, 5))\n",
    "\n",
    "    total_samples = len(data_loader.sampler)\n",
    "    batch_size = data_loader.batch_size\n",
    "    total_steps = math.ceil(total_samples / batch_size)\n",
    "    print('{0} samples ({1} per mini-batch)'.format(total_samples, batch_size))\n",
    "\n",
    "    # Switch to evaluation mode\n",
    "    model.eval()\n",
    "\n",
    "    for step, (inputs, target) in enumerate(data_loader):\n",
    "        print('[{:3d}/{:3d}] ... '.format(step + 1, total_steps), end='', flush=True)\n",
    "        with torch.no_grad():\n",
    "            inputs, target = inputs.to(device), target.to(device)\n",
    "            # compute output from model\n",
    "            output = model(inputs)\n",
    "\n",
    "            # compute loss and measure accuracy\n",
    "            loss.add(criterion(output, target).item())\n",
    "            classerr.add(output.data, target)\n",
    "            \n",
    "            print('Top1: {:.3f}  Top5: {:.3f}  Loss: {:.3f}'.format(\n",
    "                classerr.value(1), classerr.value(5), loss.mean), flush=True)\n",
    "    print('----------')\n",
    "    print('Overall ==> Top1: {:.3f}  Top5: {:.3f}  Loss: {:.3f}'.format(\n",
    "        classerr.value(1), classerr.value(5), loss.mean), flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collect Histograms\n",
    "\n",
    "Histogram collection is implemented using Distiller's \"Collector\" mechanism, specifically in the `ActivationHistogramsCollector` class. It is stats-based, meaning it requires pre-computed min/max values per-tensor to be provided.\n",
    "\n",
    "The min/max stats are expected as a dictionary with the following structure:\n",
    "```yaml\n",
    "'layer_name':\n",
    "    'inputs':\n",
    "        0:\n",
    "            'min': value\n",
    "            'max': value\n",
    "        ...\n",
    "        n:\n",
    "            'min': value\n",
    "            'max': value\n",
    "    'output':\n",
    "        'min': value\n",
    "        'max': value\n",
    "```\n",
    "Where n is the number of inputs the layer has. The `QuantCalibrationStatsCollector` collector class generates stats in the required format.\n",
    "\n",
    "To streamline this process, a utility function is provided: `distiller.data_loggers.collect_histograms`. Given a model and a test function, it performs the required stats collection followed by histogram collection. If the user has already computed min/max stats beforehand, those can be provided as a dict or as a path to a YAML file (as saved by `QuantCalibrationStatsCollector`). In that case, the stats collection pass is skipped.\n",
    "\n",
    "### Dataset Preparation in the Context of Stats-Based Histograms\n",
    "\n",
    "If the data used for min/max stats collection is not the same as the data used for histogram collection, it is highly likely that when collecting histograms some values will fall outside the pre-calculated min/max range. When that happens, the value is **clamped**. Assuming the subsets of data used in both cases are representative enough, this shouldn't have a major effect on the results.\n",
    "\n",
    "One can choose to avoid this issue by making sure the same subset of data is used in both passes. How to make sure of that will, of course, differ from one use case to another. In this example we do this by enabling the `fixed_subset` flag when calling `load_data` above."
   ]
  },
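  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the two-pass idea and the effect of clamping concrete, here is a minimal plain-Python sketch. It is not Distiller's implementation: the function names `collect_min_max` and `collect_histogram` are hypothetical, plain lists stand in for tensors, and the choice to count clamped values in the edge bins is an assumption of this sketch.\n",
    "\n",
    "```python\n",
    "def collect_min_max(batches):\n",
    "    # Pass 1: track the smallest and largest value seen across all batches.\n",
    "    lo, hi = float('inf'), float('-inf')\n",
    "    for batch in batches:\n",
    "        lo = min(lo, min(batch))\n",
    "        hi = max(hi, max(batch))\n",
    "    return {'min': lo, 'max': hi}\n",
    "\n",
    "def collect_histogram(batches, stats, nbins):\n",
    "    # Pass 2: count values into 'nbins' equal-width bins spanning [min, max].\n",
    "    # This sketch assumes max > min.\n",
    "    lo, hi = stats['min'], stats['max']\n",
    "    bin_width = (hi - lo) / nbins\n",
    "    hist = [0] * nbins\n",
    "    for batch in batches:\n",
    "        for value in batch:\n",
    "            clamped = min(max(value, lo), hi)  # out-of-range values are clamped\n",
    "            idx = min(int((clamped - lo) / bin_width), nbins - 1)\n",
    "            hist[idx] += 1\n",
    "    bin_centroids = [lo + (i + 0.5) * bin_width for i in range(nbins)]\n",
    "    return {'hist': hist, 'bin_centroids': bin_centroids}\n",
    "\n",
    "stats_batches = [[0.0, 1.0, 2.0], [3.0, 4.0]]\n",
    "hist_batches = stats_batches + [[-1.0, 9.0]]  # extra values outside [0, 4]\n",
    "stats = collect_min_max(stats_batches)\n",
    "result = collect_histogram(hist_batches, stats, nbins=4)\n",
    "print(result['hist'])\n",
    "```\n",
    "\n",
    "In this toy run the histogram pass sees two values that the stats pass never saw, so they are clamped into the first and last bins. Using the same data for both passes avoids that."
   ]
  },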
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# The test function passed to 'collect_histograms' must have an\n",
    "# argument named 'model' which accepts the model for which histograms\n",
    "# are to be collected. 'collect_histograms' will not set any other\n",
    "# arguments.\n",
    "# We'll use Python's 'partial' to set the rest of the\n",
    "# arguments for the test function before calling 'collect_histograms'\n",
    "from functools import partial\n",
    "test_fn = partial(eval_model, data_loader=test_loader)\n",
    "\n",
    "# Histogram collection parameters\n",
    "\n",
    "# 'save_dir': Pass a valid directory path to have the histogram\n",
    "#   data saved to disk. Pass None to disable saving.\n",
    "# 'save_hist_imgs': If save_dir is not None, toggles whether to save\n",
    "#   histogram images in addition to the raw data\n",
    "# 'hist_imgs_ext': Controls the filetype for histogram images\n",
    "save_dir = '.'\n",
    "save_hist_imgs = True\n",
    "hist_imgs_ext = '.png'\n",
    "\n",
    "# 'activation_stats': Here we pass None so a stats collection pass\n",
    "#   is performed.\n",
    "activation_stats = None\n",
    "\n",
    "# 'classes': To speed-up the calculation here we use the 'classes'\n",
    "#   argument so that stats and histograms are collected only for \n",
    "#   ReLU layers in the model. Pass None to collect for all layers.\n",
    "classes = [torch.nn.ReLU]\n",
    "\n",
    "# 'nbins': Number of histogram bins to use.\n",
    "nbins = 2048\n",
    "\n",
    "hist_dict = distiller.data_loggers.collect_histograms(\n",
    "    model, test_fn, save_dir=save_dir, activation_stats=activation_stats,\n",
    "    classes=classes, nbins=nbins, save_hist_imgs=save_hist_imgs, hist_imgs_ext=hist_imgs_ext)"
   ]
  },
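  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `partial` binding used for `test_fn` above can be seen in isolation in the following sketch. The names `eval_stub` and `stub_test_fn` are hypothetical stand-ins and the string and list arguments are toy data, so the snippet runs without a model or a data loader.\n",
    "\n",
    "```python\n",
    "from functools import partial\n",
    "\n",
    "def eval_stub(data_loader, model):\n",
    "    # Stand-in for an evaluation function; it only reports what it received.\n",
    "    return 'evaluated {} on {} batches'.format(model, len(data_loader))\n",
    "\n",
    "# Bind every argument except 'model' ahead of time.\n",
    "stub_test_fn = partial(eval_stub, data_loader=['batch_0', 'batch_1'])\n",
    "\n",
    "# A caller that only knows about 'model' can now invoke the function.\n",
    "print(stub_test_fn(model='toy_model'))\n",
    "```\n",
    "\n",
    "Binding `data_loader` by keyword leaves `model` as the only argument the caller still has to supply."
   ]
  },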
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot Histograms\n",
    "\n",
    "The generated dictionary has the following structure (very similar to the structure of the min/max stats dictionary described above):\n",
    "```yaml\n",
    "'layer_name':\n",
    "    'inputs':\n",
    "        0:\n",
    "            'hist': tensor             # Tensor with bin counts\n",
    "            'bin_centroids': tensor    # Tensor with activation values corresponding to center of each bin\n",
    "        ...\n",
    "        n:\n",
    "            'hist': tensor\n",
    "            'bin_centroids': tensor\n",
    "    'output':\n",
    "        'hist': tensor\n",
    "        'bin_centroids': tensor\n",
    "```"
   ]
  },
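  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of how a dictionary with this structure can be traversed, the snippet below walks every input and output record and normalizes the bin counts. The helpers `iter_histograms` and `normalize` are hypothetical, and plain lists stand in for the tensors that the real dictionary holds.\n",
    "\n",
    "```python\n",
    "def iter_histograms(hist_dict):\n",
    "    # Yield (layer_name, tensor_name, record) for every input and output entry.\n",
    "    for layer_name, data in hist_dict.items():\n",
    "        for idx, record in data['inputs'].items():\n",
    "            yield layer_name, 'input_{}'.format(idx), record\n",
    "        yield layer_name, 'output', data['output']\n",
    "\n",
    "def normalize(bin_counts):\n",
    "    # Convert raw counts to fractions that sum to 1.\n",
    "    total = sum(bin_counts)\n",
    "    return [count / total for count in bin_counts]\n",
    "\n",
    "# Toy data with plain lists standing in for tensors.\n",
    "toy_hist_dict = {\n",
    "    'relu1': {\n",
    "        'inputs': {0: {'hist': [1, 3], 'bin_centroids': [0.25, 0.75]}},\n",
    "        'output': {'hist': [2, 2], 'bin_centroids': [0.25, 0.75]},\n",
    "    }\n",
    "}\n",
    "for layer_name, tensor_name, record in iter_histograms(toy_hist_dict):\n",
    "    print(layer_name, tensor_name, normalize(record['hist']))\n",
    "```"
   ]
  },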
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
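   },
   "outputs": [],
   "source": [
    "# Uncomment this line to load saved output from a previous histogram collection run\n",
    "# hist_dict = torch.load('acts_histograms.pt')\n",
    "\n",
    "# Pretty matplotlib plots. Matplotlib 3.6 renamed this style to 'seaborn-v0_8',\n",
    "# so use whichever name the installed version provides.\n",
    "plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')\n",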
    "\n",
    "def draw_hist(layer_name, tensor_name, bin_counts, bin_centroids, normed=True, yscale='linear'):\n",
    "    if normed:\n",
    "        bin_counts = bin_counts / bin_counts.sum()\n",
    "    plt.figure(figsize=(12, 6))\n",
    "    plt.title('\\n'.join((layer_name, tensor_name)), fontsize=16)\n",
    "    plt.fill_between(bin_centroids, bin_counts, step='mid', antialiased=False)\n",
    "    if yscale == 'linear':\n",
    "        plt.ylim(bottom=0)\n",
    "    plt.yscale(yscale)\n",
    "    plt.xlabel('Activation Value')\n",
    "    plt.ylabel('Normalized Count')\n",
    "    plt.show()\n",
    "\n",
    "@interact(layer_name=hist_dict.keys(),\n",
    "          normalize_bin_counts=True,\n",
    "          y_axis_scale=['linear', 'log'])\n",
    "def draw_layer(layer_name, normalize_bin_counts, y_axis_scale):\n",
    "    print('\\nSelected layer: ' + layer_name)\n",
    "    data = hist_dict[layer_name]\n",
    "    for idx, od in data['inputs'].items():\n",
    "        draw_hist(layer_name, 'input_{}'.format(idx), od['hist'], od['bin_centroids'],\n",
    "                  normed=normalize_bin_counts, yscale=y_axis_scale)\n",
    "    od = data['output']\n",
    "    draw_hist(layer_name, 'output', od['hist'], od['bin_centroids'],\n",
    "              normed=normalize_bin_counts, yscale=y_axis_scale)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Histogram Images from Disk\n",
    "\n",
    "If you enabled saving of histogram images above, or have images from a collection executed externally, you can use the code below to display the images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from IPython.display import Image, SVG, display\n",
    "import glob\n",
    "from collections import OrderedDict\n",
    "\n",
    "# Set the path to the images directory\n",
    "imgs_dir = 'histogram_imgs'\n",
    "\n",
    "files = sorted(glob.glob(os.path.join(imgs_dir, '*.*')))\n",
    "files = [f for f in files if os.path.isfile(f)]\n",
    "fnames_map = OrderedDict([(os.path.split(f)[1], f) for f in files])\n",
    "\n",
    "@interact(file_name=fnames_map)\n",
    "def load_image(file_name):\n",
    "    if file_name.endswith('.svg'):\n",
    "        display(SVG(filename=file_name))\n",
    "    else:\n",
    "        display(Image(filename=file_name))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
